"""
All rights reserved. 
Author: Yang SONG (songyangmri@gmail.com)
"""
import os, csv
import shutil
import pandas as pd
from traceback import format_exc
from PyQt5.QtWidgets import *
from PyQt5.QtCore import pyqtSignal

from BC.DataContainer.DataContainer import DataContainer
from BC.FeatureAnalysis.IndexDict import Index2Dict
from BC.FeatureAnalysis.FeatureSelector import FeatureSelector, LoadSelectInfo
from BC.FeatureAnalysis.Classifier import LoadModel
from BC.GUI.ModelPrediction import Ui_ReuseFaeModel
from BC.Func.Metric import EstimatePrediction

from BC.Utility.Constants import *
from BC.Visualization import DrawProbability, DrawBoxPlot, DrawViolinPlot, DrawCalibrationCurve
from BC.Visualization.DrawROCList import DrawROCList, DrawPRCurveList


class ModelPredictionForm(QWidget):
    """Form that re-applies a stored FAE pipeline to a test feature matrix.

    Workflow driven by the buttons of ``Ui_ReuseFaeModel``:

    1. ``LoadTestFeature`` loads a CSV file into ``self.dc``.
    2. ``LoadFaeModel`` reads ``pipeline_info.csv`` from a result folder and
       fills the pipeline selection widgets.
    3. ``Predict`` runs normalizer -> dimension reducer -> feature selector ->
       classifier with the selected names and stores the class-1 probability.
    4. ``ShowResult`` / ``ShowCurve`` display the metrics and the chosen plot.
    5. ``Save`` writes the prediction, the metrics and all figures to a folder.

    ``close_signal`` is emitted with ``True`` when the widget is closed.
    """
    close_signal = pyqtSignal(bool)

    def __init__(self):
        super().__init__()
        self.ui = Ui_ReuseFaeModel()
        self.ui.setupUi(self)

        self.dc = DataContainer()
        self._model_root = ''

        # Results of the latest successful prediction; empty means "nothing to show".
        self.prediction = []
        self.label = []
        self.binary_metric = {}

        self.ui.buttonLoadTestFeature.clicked.connect(self.LoadTestFeature)
        self.ui.buttonLoadModel.clicked.connect(self.LoadFaeModel)
        self.ui.buttonPredict.clicked.connect(self.Predict)

        # Items are added before the signal is connected, so filling the combo
        # box does not trigger ShowCurve.
        self.ui.comboCurve.addItems(PLOT_TYPE)
        self.ui.comboCurve.currentIndexChanged.connect(self.ShowCurve)
        self.ui.checkAutoCutoff.stateChanged.connect(self.ShowResult)
        self.ui.spinCutoff.valueChanged.connect(self.ShowResult)

        self.ui.buttonSave.clicked.connect(self.Save)

    def closeEvent(self, event):
        """Emit ``close_signal`` and accept the close event."""
        self.close_signal.emit(True)
        event.accept()

    def _ResetResult(self):
        """Forget the stored prediction/metrics and blank the table and the canvas."""
        # Dropping the stored result keeps Save/ShowCurve from pairing an old
        # prediction with newly loaded data or a newly loaded model.
        self.prediction = []
        self.binary_metric = {}
        self.ui.tableResult.clear()
        self.ui.canvas.getFigure().clear()
        self.ui.canvas.draw()

    @staticmethod
    def _IsSupportedVersion(version_text):
        """Return True when ``version_text`` ("major.minor.patch") is at least 0.3.0.

        Missing components are treated as 0; text that cannot be parsed as
        integers is reported as unsupported instead of raising.
        """
        try:
            parts = [int(part) for part in version_text.split('.')[:3]]
        except ValueError:
            return False
        parts += [0] * (3 - len(parts))
        return tuple(parts) >= (0, 3, 0)

    @staticmethod
    def _OpenFolder(folder):
        """Best-effort: show ``folder`` in the platform file manager."""
        import subprocess
        import sys

        path = os.path.normpath(folder)
        try:
            if sys.platform.startswith('win'):
                os.startfile(path)
            elif sys.platform == 'darwin':
                subprocess.Popen(['open', path])
            else:
                subprocess.Popen(['xdg-open', path])
        except OSError:
            # The results are already written; failing to open a file manager
            # window must not be reported as a failed save.
            pass

    def LoadTestFeature(self):
        """Ask for a CSV file and load it into ``self.dc`` as the test feature matrix.

        Any previously shown result is discarded. Errors are reported in a
        message box instead of being raised.
        """
        dlg = QFileDialog()
        file_name, _ = dlg.getOpenFileName(self, 'Open CSV file', filter="csv files (*.csv)")
        if not file_name:
            return

        try:
            self._ResetResult()

            if self.dc.Load(file_name):
                self.label = self.dc.GetLabel()
                self.ui.lineTestFeatureMatrixLoader.setText(file_name)
            else:
                QMessageBox().about(self, 'Error', 'Load Failed. May there is no Label')
        except Exception:
            QMessageBox().about(self, 'Error', format_exc())

    def LoadFaeModel(self):
        """Ask for an FAE result folder and fill the pipeline widgets from it.

        The folder must contain ``pipeline_info.csv``. The file is read and
        validated completely before any widget or ``self._model_root`` is
        changed, so a rejected folder leaves the form in its previous state.
        """
        dlg = QFileDialog()
        dlg.setFileMode(QFileDialog.DirectoryOnly)
        dlg.setOption(QFileDialog.ShowDirsOnly)

        message_box = QMessageBox()
        if not dlg.exec_():
            return
        model_root = dlg.selectedFiles()[0]

        pipeline_info_path = os.path.join(model_root, 'pipeline_info.csv')
        if not os.path.exists(pipeline_info_path):
            message_box.about(self, 'File Error', 'The file pipeline_info does not exists')
            return

        with open(pipeline_info_path, 'r', newline='') as csvfile:
            # Blank lines yield empty rows; skip them so row[0] is always valid.
            rows = [row for row in csv.reader(csvfile) if row]

        # Only FAE results from version 0.3.0 or later are accepted.
        for row in rows:
            if row[0] == 'Version' and not self._IsSupportedVersion(''.join(row[1:2])):
                message_box.about(self, '', 'The result generated by FAE must be equal to or larger than 0.3.0')
                return

        self.ui.comboNormalizer.clear()
        self.ui.comboDimensionReduction.clear()
        self.ui.comboFeatureSelector.clear()
        self.ui.comboClassifier.clear()
        self.ui.spinBoxFeatureNumber.setValue(0)
        self._ResetResult()

        combo_by_key = {
            'Normalizer': self.ui.comboNormalizer,
            'DimensionReduction': self.ui.comboDimensionReduction,
            'FeatureSelector': self.ui.comboFeatureSelector,
            'Classifier': self.ui.comboClassifier,
        }
        try:
            for row in rows:
                key = row[0]
                if key in combo_by_key:
                    combo_by_key[key].addItems(row[1:])
                elif key == 'FeatureNumber' and len(row) > 1:
                    # The first and the last entries bound the selectable feature number.
                    self.ui.spinBoxFeatureNumber.setMinimum(int(row[1]))
                    self.ui.spinBoxFeatureNumber.setMaximum(int(row[-1]))
        except ValueError:
            message_box.about(self, 'File Error', format_exc())
            return

        self._model_root = model_root
        self.ui.lineModelPath.setText(self._model_root)

    def Predict(self):
        """Run the selected pipeline on the loaded features and show the result.

        Every stage reads its stored information from below
        ``self._model_root``; a missing file or a failing stage is reported in
        a message box and aborts the prediction.
        """
        self.binary_metric = {}

        norm_name = self.ui.comboNormalizer.currentText()
        reduce_name = self.ui.comboDimensionReduction.currentText()
        select_name = self.ui.comboFeatureSelector.currentText()
        feature_number = self.ui.spinBoxFeatureNumber.value()
        classifier_name = self.ui.comboClassifier.currentText()

        message = QMessageBox()
        index_dictor = Index2Dict()

        # Normalize Features
        norm_path = os.path.join(self._model_root, norm_name, '{}_normalization_training.csv'.format(norm_name))
        if not os.path.exists(norm_path):
            message.about(self, '', '{} not exists'.format(norm_path))
            return
        # Keep only the features listed in the training normalization file.
        try:
            used_features = pd.read_csv(norm_path)['feature_name'].tolist()
            used_dc = FeatureSelector().SelectFeatureByName(self.dc, used_features)
        except Exception:
            message.about(self, 'Check the valid of the featrures.', format_exc())
            return

        normalizer = index_dictor.GetInstantByIndex(norm_name)
        normalizer.LoadInfo(norm_path)
        try:
            norm_dc = normalizer.Transform(used_dc)
        except ValueError:
            message.about(self, 'Normalization Wrong', format_exc())
            return

        # Dimension Reducer
        dr_folder = os.path.join(self._model_root, norm_name, reduce_name)
        if not os.path.exists(dr_folder):
            message.about(self, '', '{} not exists'.format(dr_folder))
            return
        try:
            reducer = index_dictor.GetInstantByIndex(reduce_name)
            reducer.LoadInfo(dr_folder)
            dr_dc = reducer.Transform(norm_dc)
        except Exception:
            message.about(self, 'Dimension Reduction Wrong', format_exc())
            return

        # Feature Select
        fs_folder = os.path.join(dr_folder, '{}_{}'.format(select_name, feature_number))
        fs_info_path = os.path.join(fs_folder, 'feature_select_info.csv')
        if not os.path.exists(fs_info_path):
            message.about(self, '', '{} not exists'.format(fs_info_path))
            return
        try:
            _, selected_features = LoadSelectInfo(fs_info_path)
            fs_dc = FeatureSelector().SelectFeatureByName(dr_dc, selected_features)
        except Exception:
            message.about(self, 'Feature Selection Wrong', format_exc())
            return

        # Predict
        cls_path = os.path.join(fs_folder, classifier_name, 'model.pickle')
        if not os.path.exists(cls_path):
            message.about(self, '', '{} not exists'.format(cls_path))
            return
        try:
            model = LoadModel(cls_path)
            # Column 1 of predict_proba is kept as the prediction score.
            prediction = model.predict_proba(fs_dc.GetArray())[:, 1]
        except Exception:
            message.about(self, 'Prediction Wrong', format_exc())
            return
        self.prediction = prediction

        # ShowResult finishes by calling ShowCurve, so no extra redraw is needed.
        self.ShowResult()

    def ShowResult(self):
        """Recompute the metrics for the current cutoff setting and fill the table.

        Also used as the slot of the cutoff widgets; without a stored
        prediction only the enabled state of the cutoff spin box is updated.
        """
        auto_cutoff = self.ui.checkAutoCutoff.isChecked()
        self.ui.spinCutoff.setEnabled(not auto_cutoff)

        if len(self.prediction) == 0:
            return

        if auto_cutoff:
            self.binary_metric = EstimatePrediction(self.prediction, self.dc.GetLabel())
        else:
            self.binary_metric = EstimatePrediction(self.prediction, self.dc.GetLabel(), cutoff=self.ui.spinCutoff.value())

        self.ui.tableResult.clear()
        self.ui.tableResult.setRowCount(len(self.binary_metric))
        self.ui.tableResult.setColumnCount(2)

        for index, (key, value) in enumerate(self.binary_metric.items()):
            self.ui.tableResult.setItem(index, 0, QTableWidgetItem(key))
            self.ui.tableResult.setItem(index, 1, QTableWidgetItem(str(value)))

        self.ui.comboCurve.setCurrentText(PROBABILITY)
        self.ShowCurve()

    def ShowCurve(self):
        """Draw the plot selected in ``comboCurve`` onto the canvas.

        Does nothing until a prediction and its metrics are available.

        Raises:
            KeyError: if the selected plot type is not one of the known types.
        """
        if len(self.prediction) == 0 or not self.binary_metric:
            return

        method = self.ui.comboCurve.currentText()
        if method == ROC_CURVE:
            DrawROCList([self.prediction], [self.label], is_show=False, fig=self.ui.canvas.getFigure())
        elif method == PR_CURVE:
            DrawPRCurveList([self.prediction], [self.label], is_show=False, fig=self.ui.canvas.getFigure())
        elif method == PROBABILITY:
            DrawProbability(self.prediction, self.label, cut_off=float(self.binary_metric[CUTOFF]), fig=self.ui.canvas.getFigure())
        elif method == CALIBRATION_CURVE:
            DrawCalibrationCurve(self.prediction, self.label, fig=self.ui.canvas.getFigure())
        elif method == BOX_PLOT:
            DrawBoxPlot(self.prediction, self.label, fig=self.ui.canvas.getFigure())
        elif method == VIOLIN_PLOT:
            DrawViolinPlot(self.prediction, self.label, fig=self.ui.canvas.getFigure())
        else:
            raise KeyError('Not existed method')

        self.ui.canvas.draw()

    def Save(self):
        """Write the prediction, the metrics and all figures into a chosen folder.

        A non-empty folder is emptied only after the user confirms; answering
        "No" cancels the save. Nothing is touched when there is no result yet.
        """
        # Check first: emptying a folder and then saving nothing would only destroy data.
        if len(self.binary_metric) == 0:
            QMessageBox().about(self, '', 'There is no prediction result to save')
            return

        dlg = QFileDialog()
        dlg.setFileMode(QFileDialog.DirectoryOnly)
        dlg.setOption(QFileDialog.ShowDirsOnly)

        if not dlg.exec_():
            return
        store_folder = dlg.selectedFiles()[0]

        if len(os.listdir(store_folder)) > 0:
            reply = QMessageBox.question(self, 'Continue?',
                                         'The folder is not empty, if you click Yes, the data would be over-written in this folder',
                                         QMessageBox.Yes, QMessageBox.No)
            if reply != QMessageBox.Yes:
                # The user declined to overwrite, so nothing is written at all.
                return
            try:
                for file in os.listdir(store_folder):
                    target = os.path.join(store_folder, file)
                    if os.path.isdir(target):
                        shutil.rmtree(target)
                    else:
                        os.remove(target)
            except OSError:
                # PermissionError is a subclass of OSError, one handler covers both.
                QMessageBox().about(self, 'Warning', 'Is there any opened files?')
                return

        # Store the prediction and the related figures
        pred_df = pd.DataFrame({'prediction': self.prediction, 'label': self.label}, index=self.dc.GetCaseName())
        pred_df.to_csv(os.path.join(store_folder, 'prediction.csv'))
        metric_df = pd.DataFrame(self.binary_metric, index=['Metric'])
        metric_df.to_csv(os.path.join(store_folder, 'metric.csv'))

        DrawROCList([self.prediction], [self.label], store_path=os.path.join(store_folder, 'ROC.jpg'), is_show=False)
        DrawPRCurveList([self.prediction], [self.label], store_path=os.path.join(store_folder, 'PR-ROC.jpg'), is_show=False)
        DrawProbability(self.prediction, self.label, cut_off=float(self.binary_metric[CUTOFF]), store_path=os.path.join(store_folder, 'probability.jpg'))
        DrawCalibrationCurve(self.prediction, self.label, store_path=os.path.join(store_folder, 'calibration.jpg'))
        DrawBoxPlot(self.prediction, self.label, store_path=os.path.join(store_folder, 'boxplot.jpg'))
        DrawViolinPlot(self.prediction, self.label, store_path=os.path.join(store_folder, 'violinplot.jpg'))

        self._OpenFolder(store_folder)


if __name__ == '__main__':
    # Stand-alone launch: show the prediction form in its own Qt application
    # and hand the event loop's return code back to the interpreter.
    import sys

    application = QApplication(sys.argv)
    main_form = ModelPredictionForm()
    main_form.show()
    exit_code = application.exec_()
    sys.exit(exit_code)